package com.yeskery.nut.util;

import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.net.JarURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
 * Class utilities: resolves classes by name and scans packages on the classpath
 * (both exploded directories and jar archives) for the classes they contain.
 * @author sprout
 * 2022-06-10 14:00
 */
public final class ClassUtils {

    /** Logger used to report scan and class-load failures. */
    private static final Logger logger = Logger.getLogger(ClassUtils.class.getName());

    /** URL protocol of classpath roots that are plain directories. */
    private static final String PROTOCOL_FILE = "file";

    /** URL protocol of classpath roots that live inside a jar archive. */
    private static final String PROTOCOL_JAR = "jar";

    /** File-name suffix of compiled class files. */
    private static final String CLASS_FILE_POSTFIX = ".class";

    /** Simple name of the module descriptor class file; it is not a loadable type. */
    private static final String MODULE_INFO = "module-info";

    /** Simple name of the package descriptor class file; it is not a regular type. */
    private static final String PACKAGE_INFO = "package-info";

    /**
     * Private constructor: this is a static utility class and must not be instantiated.
     */
    private ClassUtils() {
    }

    /**
     * Looks up a class by its fully qualified name without initializing it.
     * @param className fully qualified class name
     * @return the class object, or <code>null</code> if the class cannot be found
     */
    public static Class<?> getClassByName(String className) {
        try {
            return Class.forName(className, false, ClassLoaderUtils.getClassLoader());
        } catch (ClassNotFoundException e) {
            return null;
        }
    }

    /**
     * Checks whether a class with the given name can be found on the current classpath.
     * The class is loaded but not initialized.
     * @param className fully qualified class name
     * @return <code>true</code> if the class can be found
     */
    public static boolean isExistTargetClass(String className) {
        return getClassByName(className) != null;
    }

    /**
     * Returns the classes in the given package that pass the filter.
     * @param packageName package to scan, e.g. {@code com.example}
     * @param isRecursive whether sub-packages are scanned as well
     * @param classFilterFunction decides which classes to keep; a <code>null</code> result is
     *                            treated as <code>false</code>
     * @return the matching classes, never <code>null</code>
     */
    public static List<Class<?>> getClassList(String packageName, boolean isRecursive,
                                              Function<Class<?>, Boolean> classFilterFunction) {
        return getClassList(packageName, isRecursive).stream()
                // Boolean.TRUE.equals avoids an NPE when the filter returns null.
                .filter(cls -> Boolean.TRUE.equals(classFilterFunction.apply(cls)))
                .collect(Collectors.toList());
    }

    /**
     * Returns all classes in the given package. Classes are loaded and initialized.
     * Failures are logged rather than thrown: an unreadable classpath root or a class
     * that cannot be loaded is skipped and the scan continues.
     * @param packageName package to scan, e.g. {@code com.example}
     * @param isRecursive whether sub-packages are scanned as well
     * @return the classes found, never <code>null</code>
     */
    public static List<Class<?>> getClassList(String packageName, boolean isRecursive) {
        List<Class<?>> classList = new ArrayList<>();
        scanPackage(packageName, isRecursive, "getClassList", classList::add);
        return classList;
    }

    /**
     * Scans the given package and all of its sub-packages.
     * {@code classFilter} maps each scanned class to a key class. When the key is non-null,
     * the scanned class is added to the collection returned by {@code classFunction}.
     * @param packageName package to scan, e.g. {@code com.example}
     * @param classFilter maps a scanned class to a key class, or to <code>null</code> to skip it
     * @param classFunction given the scanned class and its non-null key, supplies the
     *                      collection the scanned class is added to
     */
    public static void obtainClassListByFilter(String packageName, Function<Class<?>, Class<?>> classFilter,
                                               BiFunction<Class<?>, Class<?>, Collection<Class<?>>> classFunction) {
        scanPackage(packageName, true, "obtainClassListByFilter", cls -> {
            Class<?> clazz = classFilter.apply(cls);
            if (clazz != null) {
                classFunction.apply(cls, clazz).add(cls);
            }
        });
    }

    /**
     * Returns all classes in the given package and its sub-packages that carry the annotation.
     * @param packageName package to scan, e.g. {@code com.example}
     * @param annotationClass annotation the classes must be annotated with
     * @return the matching classes, never <code>null</code>
     */
    public static List<Class<?>> getClassListByAnnotation(String packageName, Class<? extends Annotation> annotationClass) {
        List<Class<?>> classList = new ArrayList<>();
        obtainClassListByFilter(packageName, cls -> cls.isAnnotationPresent(annotationClass) ? annotationClass : null, (cls, clazz) -> classList);
        return classList;
    }

    /**
     * Walks every classpath root that contains the package and hands each loaded class to the consumer.
     * @param packageName package to scan
     * @param isRecursive whether sub-packages are scanned as well
     * @param sourceMethod public method name reported in log records
     * @param classConsumer receives every class that was loaded successfully
     */
    private static void scanPackage(String packageName, boolean isRecursive, String sourceMethod,
                                    Consumer<Class<?>> classConsumer) {
        String packagePath;
        Enumeration<URL> urls;
        try {
            packagePath = packageName.replace('.', '/');
            // Use the same loader that loads the classes, so every class found can also be loaded.
            urls = ClassLoaderUtils.getClassLoader().getResources(packagePath);
        } catch (Exception e) {
            logger.logp(Level.SEVERE, ClassUtils.class.getName(), sourceMethod, "Search Package Class Fail.", e);
            return;
        }
        while (urls.hasMoreElements()) {
            URL url = urls.nextElement();
            // Each root is scanned independently, so one broken jar or directory
            // does not hide the classes found in the other roots.
            try {
                String protocol = url.getProtocol();
                if (PROTOCOL_FILE.equals(protocol)) {
                    scanDirectory(toFile(url), packageName, isRecursive, classConsumer);
                } else if (PROTOCOL_JAR.equals(protocol)) {
                    scanJar((JarURLConnection) url.openConnection(), packagePath, isRecursive, classConsumer);
                }
            } catch (Exception e) {
                logger.logp(Level.SEVERE, ClassUtils.class.getName(), sourceMethod, "Search Package Class Fail.", e);
            }
        }
    }

    /**
     * Converts a {@code file:} URL to a file, decoding percent-escapes such as {@code %20}.
     * @param url a URL whose protocol is {@code file}
     * @return the corresponding file
     */
    private static File toFile(URL url) {
        try {
            return new File(url.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            // Some loaders hand out URLs that are not valid URIs (e.g. raw spaces) or carry an
            // authority component; fall back to the raw path.
            return new File(url.getPath());
        }
    }

    /**
     * Loads every class file in an exploded package directory.
     * @param directory directory that corresponds to {@code packageName}
     * @param packageName package name of {@code directory}
     * @param isRecursive whether sub-directories (sub-packages) are scanned too
     * @param classConsumer receives every class that was loaded successfully
     */
    private static void scanDirectory(File directory, String packageName, boolean isRecursive,
                                      Consumer<Class<?>> classConsumer) {
        File[] children = directory.listFiles(file -> file.isDirectory()
                || (file.isFile() && file.getName().endsWith(CLASS_FILE_POSTFIX)));
        if (children == null) {
            return;
        }
        for (File child : children) {
            String childName = child.getName();
            if (child.isFile()) {
                String simpleName = childName.substring(0, childName.length() - CLASS_FILE_POSTFIX.length());
                loadAndAccept(qualify(packageName, simpleName), classConsumer);
            } else if (isRecursive) {
                scanDirectory(child, qualify(packageName, childName), true, classConsumer);
            }
        }
    }

    /**
     * Loads every class file under the package path inside a jar.
     * @param connection connection to the jar resource that matched the package
     * @param packagePath package name with '.' replaced by '/'
     * @param isRecursive whether nested package directories are included
     * @param classConsumer receives every class that was loaded successfully
     * @throws IOException if the jar cannot be opened or closed
     */
    private static void scanJar(JarURLConnection connection, String packagePath, boolean isRecursive,
                                Consumer<Class<?>> classConsumer) throws IOException {
        JarFile jarFile = connection.getJarFile();
        try {
            // Matching on "<path>/" restricts the scan to the package itself, so
            // "com/foo" does not also match "com/foobar".
            String entryPrefix = packagePath.isEmpty() ? "" : packagePath + "/";
            Enumeration<JarEntry> entries = jarFile.entries();
            while (entries.hasMoreElements()) {
                String entryName = entries.nextElement().getName();
                if (!entryName.startsWith(entryPrefix) || !entryName.endsWith(CLASS_FILE_POSTFIX)) {
                    continue;
                }
                // Non-recursive: skip entries that sit in a nested directory below the package.
                if (!isRecursive && entryName.indexOf('/', entryPrefix.length()) >= 0) {
                    continue;
                }
                String className = entryName
                        .substring(0, entryName.length() - CLASS_FILE_POSTFIX.length())
                        .replace('/', '.');
                loadAndAccept(className, classConsumer);
            }
        } finally {
            // A cached JarFile is shared through the URL cache and must stay open;
            // an uncached one belongs to us.
            if (!connection.getUseCaches()) {
                jarFile.close();
            }
        }
    }

    /**
     * Loads and initializes the named class and passes it to the consumer.
     * A class that fails to load is logged and skipped.
     * @param className fully qualified class name
     * @param classConsumer receives the class if it was loaded
     */
    private static void loadAndAccept(String className, Consumer<Class<?>> classConsumer) {
        String simpleName = className.substring(className.lastIndexOf('.') + 1);
        // Module and package descriptors are not ordinary types; loading module-info fails.
        if (MODULE_INFO.equals(simpleName) || PACKAGE_INFO.equals(simpleName)) {
            return;
        }
        Class<?> cls;
        try {
            cls = Class.forName(className, true, ClassLoaderUtils.getClassLoader());
        } catch (ClassNotFoundException | LinkageError e) {
            logger.logp(Level.SEVERE, ClassUtils.class.getName(), "loadAndAccept", "Load Class Fail: " + className, e);
            return;
        }
        classConsumer.accept(cls);
    }

    /**
     * Prefixes a simple name with its package name, unless the package is the default package.
     * @param packageName package name, possibly empty
     * @param simpleName class or sub-package simple name
     * @return the qualified name
     */
    private static String qualify(String packageName, String simpleName) {
        return StringUtils.isEmpty(packageName) ? simpleName : packageName + "." + simpleName;
    }
}
